{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a673b451",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def corr2d(X,K):    #X为输入，K为核矩阵\n",
    "    h,w=K.shape    #h得到K的行数，w得到K的列数\n",
    "    Y=torch.zeros((X.shape[0]-h+1,X.shape[1]-w+1))  #用0初始化输出矩阵Y\n",
    "    for i in range(Y.shape[0]):   #卷积运算\n",
    "        for j in range(Y.shape[1]):\n",
    "          Y[i,j]=(X[i:i+h,j:j+w]*K).sum()\n",
    "    return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "85e0f332",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[19., 25.],\n",
       "        [37., 43.]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on a small example that can be verified by hand\n",
    "X = torch.tensor([\n",
    "    [0, 1, 2],\n",
    "    [3, 4, 5],\n",
    "    [6, 7, 8],\n",
    "])\n",
    "K = torch.tensor([\n",
    "    [0, 1],\n",
    "    [2, 3],\n",
    "])\n",
    "corr2d(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3b9a8de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#实现二维卷积层\n",
    "class Conv2d(nn.Module):\n",
    "    def _init_(self,kernel_size):\n",
    "        super()._init_()\n",
    "        self.weight=nn.Parameter(torch.rand(kerner_size))\n",
    "        self.bias=nn.Parameter(torch.zeros(1))\n",
    "    def forward(self,x):\n",
    "        return corr2d(x,self.weight)+self.bias "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "555f2489",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X=torch.ones((6,8))\n",
    "X[:,2:6]=0\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "859d2935",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0., -1.,  0.,  0.,  0.,  1.,  0.],\n",
       "        [ 0., -1.,  0.,  0.,  0.,  1.,  0.],\n",
       "        [ 0., -1.,  0.,  0.,  0.,  1.,  0.],\n",
       "        [ 0., -1.,  0.,  0.,  0.,  1.,  0.],\n",
       "        [ 0., -1.,  0.,  0.,  0.,  1.,  0.],\n",
       "        [ 0., -1.,  0.,  0.,  0.,  1.,  0.]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A 1x2 difference kernel: it responds only where horizontally adjacent\n",
    "# entries differ, so it can detect vertical edges only.\n",
    "K = torch.tensor([[-1, 1]])\n",
    "Y = corr2d(X, K)\n",
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "926cb124",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transposing the image turns its vertical edges into horizontal ones,\n",
    "# which the 1x2 kernel K cannot detect, so every output entry is zero.\n",
    "corr2d(X.transpose(0, 1), K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "2eb05498",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch 2, loss 3.852\n",
      "batch 4, loss 1.126\n",
      "batch 6, loss 0.386\n",
      "batch 8, loss 0.145\n",
      "batch 10, loss 0.057\n"
     ]
    }
   ],
   "source": [
    "# Learn a 1x2 kernel from the (X, Y) pair with plain gradient descent.\n",
    "# One input channel, one output channel, and no bias term.\n",
    "conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)\n",
    "\n",
    "# nn.Conv2d works on (batch, channels, height, width) tensors. Bind the 4-D\n",
    "# views to new names so the 2-D X and Y stay intact and the earlier cells\n",
    "# that call corr2d(X, K) can still be re-run after this one.\n",
    "X_4d = X.reshape((1, 1, 6, 8))\n",
    "Y_4d = Y.reshape((1, 1, 6, 7))\n",
    "\n",
    "lr = 3e-2        # learning rate\n",
    "num_steps = 10   # number of gradient-descent updates\n",
    "\n",
    "for step in range(num_steps):\n",
    "    Y_hat = conv2d(X_4d)\n",
    "    loss = ((Y_hat - Y_4d) ** 2).sum()  # summed squared error, computed once\n",
    "    conv2d.zero_grad()\n",
    "    loss.backward()\n",
    "    # Apply the update with autograd disabled instead of writing via `.data`.\n",
    "    with torch.no_grad():\n",
    "        conv2d.weight -= lr * conv2d.weight.grad\n",
    "    if (step + 1) % 2 == 0:\n",
    "        print(f'batch {step + 1}, loss {loss:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "a626a495",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0173,  0.9685]])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the learned kernel as a 1x2 matrix; after training it should be\n",
    "# close to the [[-1, 1]] kernel that produced the targets.\n",
    "conv2d.weight.data.reshape(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5722b641",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36torch040]",
   "language": "python",
   "name": "conda-env-py36torch040-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
